{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "%run ./tutorials/wikiqa/init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.visible_device_list=\"1\"\n",
    "config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU\n",
    "sess = tf.Session(config=config)\n",
    "set_session(sess)  # set this TensorFlow session as the default session for Keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_filtered_data(preprocessor, data_type):\n",
    "    assert ( data_type in ['train', 'dev', 'test'])\n",
    "    data_pack = mz.datasets.wiki_qa.load_data(data_type, task='ranking')\n",
    "\n",
    "    if data_type == 'train':\n",
    "        X, Y = preprocessor.fit_transform(data_pack).unpack()\n",
    "    else:\n",
    "        X, Y = preprocessor.transform(data_pack).unpack()\n",
    "\n",
    "    new_idx = []\n",
    "    for i in range(Y.shape[0]):\n",
    "        if X[\"length_left\"][i] == 0 or X[\"length_right\"][i] == 0:\n",
    "            continue\n",
    "        new_idx.append(i)\n",
    "    new_idx = np.array(new_idx)\n",
    "    print(\"Removed empty data. Found \", (Y.shape[0] - new_idx.shape[0]))\n",
    "\n",
    "    for k in X.keys():\n",
    "        X[k] = X[k][new_idx]\n",
    "    Y = Y[new_idx]\n",
    "\n",
    "    pos_idx = (Y == 1)[:, 0]\n",
    "    pos_qid = X[\"id_left\"][pos_idx]\n",
    "    keep_idx_bool = np.array([ qid in pos_qid for qid in X[\"id_left\"]])\n",
    "    keep_idx = np.arange(keep_idx_bool.shape[0])\n",
    "    keep_idx = keep_idx[keep_idx_bool]\n",
    "    print(\"Removed questions with no pos label. Found \", (keep_idx_bool == 0).sum())\n",
    "\n",
    "    print(\"shuffling...\")\n",
    "    np.random.shuffle(keep_idx)\n",
    "    for k in X.keys():\n",
    "        X[k] = X[k][keep_idx]\n",
    "    Y = Y[keep_idx]\n",
    "\n",
    "    return X, Y, preprocessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 12754.26it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6500.31it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 1215206.55it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 185258.28it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 184455.70it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 922581.36it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 1082236.12it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3795031.47it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 13650.60it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6764.51it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 171037.31it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 288623.28it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 90725.37it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 583636.81it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 1203693.44it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 193145.54it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 134549.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed empty data. Found  38\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 296/296 [00:00<00:00, 14135.26it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval:   0%|          | 0/2708 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed questions with no pos label. Found  11672\n",
      "shuffling...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2708/2708 [00:00<00:00, 6731.87it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 168473.93it/s]\n",
      "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 204701.40it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 159066.95it/s]\n",
      "Processing length_left with len: 100%|██████████| 296/296 [00:00<00:00, 442607.48it/s]\n",
      "Processing length_right with len: 100%|██████████| 2708/2708 [00:00<00:00, 1038699.15it/s]\n",
      "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 149130.81it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 140864.36it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 633/633 [00:00<00:00, 12189.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed empty data. Found  2\n",
      "Removed questions with no pos label. Found  1601\n",
      "shuffling...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 5961/5961 [00:00<00:00, 7064.16it/s]\n",
      "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 187399.25it/s]\n",
      "Processing text_left with transform: 100%|██████████| 633/633 [00:00<00:00, 259733.36it/s]\n",
      "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 160878.23it/s]\n",
      "Processing length_left with len: 100%|██████████| 633/633 [00:00<00:00, 688714.51it/s]\n",
      "Processing length_right with len: 100%|██████████| 5961/5961 [00:00<00:00, 1166965.98it/s]\n",
      "Processing text_left with transform: 100%|██████████| 633/633 [00:00<00:00, 158526.06it/s]\n",
      "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 137558.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed empty data. Found  18\n",
      "Removed questions with no pos label. Found  3805\n",
      "shuffling...\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=20,\n",
    "                                                  fixed_length_right=40,\n",
    "                                                  remove_stop_words=False)\n",
    "train_X, train_Y, preprocessor = load_filtered_data(preprocessor, 'train')\n",
    "val_X, val_Y, _ = load_filtered_data(preprocessor, 'dev')\n",
    "pred_X, pred_Y, _ = load_filtered_data(preprocessor, 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 20)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 40)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             5002500     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             multiple             0           embedding[0][0]                  \n",
      "                                                                 embedding[1][0]                  \n",
      "                                                                 dense_1[0][0]                    \n",
      "                                                                 dense_1[1][0]                    \n",
      "                                                                 dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               multiple             0           text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_1 (Bidirectional) multiple             1442400     dropout_1[0][0]                  \n",
      "                                                                 dropout_1[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_2 (Lambda)               (None, 20, 1)        0           lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_3 (Lambda)               (None, 40, 1)        0           lambda_1[1][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "multiply_1 (Multiply)           (None, 20, 600)      0           bidirectional_1[0][0]            \n",
      "                                                                 lambda_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "multiply_2 (Multiply)           (None, 40, 600)      0           bidirectional_1[1][0]            \n",
      "                                                                 lambda_3[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_4 (Lambda)               (None, 20, 1)        0           lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_5 (Lambda)               (None, 1, 40)        0           lambda_1[1][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 20, 40)       0           multiply_1[0][0]                 \n",
      "                                                                 multiply_2[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "multiply_3 (Multiply)           (None, 20, 40)       0           lambda_4[0][0]                   \n",
      "                                                                 lambda_5[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "permute_1 (Permute)             (None, 40, 20)       0           dot_1[0][0]                      \n",
      "                                                                 multiply_3[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "atten_mask (Lambda)             multiple             0           dot_1[0][0]                      \n",
      "                                                                 multiply_3[0][0]                 \n",
      "                                                                 permute_1[0][0]                  \n",
      "                                                                 permute_1[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "softmax_1 (Softmax)             multiple             0           atten_mask[0][0]                 \n",
      "                                                                 atten_mask[1][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dot_2 (Dot)                     (None, 20, 600)      0           softmax_1[0][0]                  \n",
      "                                                                 multiply_2[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dot_3 (Dot)                     (None, 40, 600)      0           softmax_1[1][0]                  \n",
      "                                                                 multiply_1[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "subtract_1 (Subtract)           (None, 20, 600)      0           multiply_1[0][0]                 \n",
      "                                                                 dot_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "multiply_4 (Multiply)           (None, 20, 600)      0           multiply_1[0][0]                 \n",
      "                                                                 dot_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "subtract_2 (Subtract)           (None, 40, 600)      0           multiply_2[0][0]                 \n",
      "                                                                 dot_3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "multiply_5 (Multiply)           (None, 40, 600)      0           multiply_2[0][0]                 \n",
      "                                                                 dot_3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 20, 2400)     0           multiply_1[0][0]                 \n",
      "                                                                 dot_2[0][0]                      \n",
      "                                                                 subtract_1[0][0]                 \n",
      "                                                                 multiply_4[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_2 (Concatenate)     (None, 40, 2400)     0           multiply_2[0][0]                 \n",
      "                                                                 dot_3[0][0]                      \n",
      "                                                                 subtract_2[0][0]                 \n",
      "                                                                 multiply_5[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 multiple             720300      concatenate_1[0][0]              \n",
      "                                                                 concatenate_2[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_2 (Bidirectional) multiple             1442400     dropout_1[2][0]                  \n",
      "                                                                 dropout_1[3][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_6 (Lambda)               (None, 20, 1)        0           lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_8 (Lambda)               (None, 20, 1)        0           lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_10 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_12 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "multiply_6 (Multiply)           (None, 20, 600)      0           bidirectional_2[0][0]            \n",
      "                                                                 lambda_6[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "multiply_7 (Multiply)           (None, 20, 600)      0           bidirectional_2[0][0]            \n",
      "                                                                 lambda_8[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "multiply_8 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            \n",
      "                                                                 lambda_10[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "multiply_9 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            \n",
      "                                                                 lambda_12[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_7 (Lambda)               (None, 600)          0           multiply_6[0][0]                 \n",
      "                                                                 lambda_6[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "lambda_9 (Lambda)               (None, 600)          0           multiply_7[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "lambda_11 (Lambda)              (None, 600)          0           multiply_8[0][0]                 \n",
      "                                                                 lambda_10[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_13 (Lambda)              (None, 600)          0           multiply_9[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_3 (Concatenate)     (None, 1200)         0           lambda_7[0][0]                   \n",
      "                                                                 lambda_9[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_4 (Concatenate)     (None, 1200)         0           lambda_11[0][0]                  \n",
      "                                                                 lambda_13[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_5 (Concatenate)     (None, 2400)         0           concatenate_3[0][0]              \n",
      "                                                                 concatenate_4[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 300)          720300      concatenate_5[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 2)            602         dropout_1[4][0]                  \n",
      "==================================================================================================\n",
      "Total params: 9,328,502\n",
      "Trainable params: 4,326,002\n",
      "Non-trainable params: 5,002,500\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "from keras.optimizers import Adam\n",
    "import matchzoo\n",
    "\n",
    "model = matchzoo.contrib.models.ESIM()\n",
    "\n",
    "# update `input_shapes` and `embedding_input_dim`\n",
    "# model.params['task'] = mz.tasks.Ranking() \n",
    "# or \n",
    "model.params['task'] = mz.tasks.Classification(num_classes=2)\n",
    "model.params.update(preprocessor.context)\n",
    "\n",
    "model.params['mask_value'] = 0\n",
    "model.params['lstm_dim'] = 300\n",
    "model.params['embedding_output_dim'] = 300\n",
    "model.params['embedding_trainable'] = False\n",
    "model.params['dropout_rate'] = 0.5\n",
    "\n",
    "model.params['mlp_num_units'] = 300\n",
    "model.params['mlp_num_layers'] = 0\n",
    "model.params['mlp_num_fan_out'] = 300\n",
    "model.params['mlp_activation_func'] = 'tanh'\n",
    "model.params['optimizer'] = Adam(lr=1e-4)\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'], initializer=lambda: 0)\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 8650 samples, validate on 1130 samples\n",
      "Epoch 1/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0985 - val_loss: 0.0977\n",
      "Validation: mean_average_precision(0.0): 0.6377925262180991\n",
      "Epoch 2/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0947 - val_loss: 0.0939\n",
      "Validation: mean_average_precision(0.0): 0.6323746460063332\n",
      "Epoch 3/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0923 - val_loss: 0.0896\n",
      "Validation: mean_average_precision(0.0): 0.6447892278707743\n",
      "Epoch 4/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0895 - val_loss: 0.0904\n",
      "Validation: mean_average_precision(0.0): 0.6645210508066117\n",
      "Epoch 5/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0883 - val_loss: 0.0900\n",
      "Validation: mean_average_precision(0.0): 0.6622282952529867\n",
      "Epoch 6/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0839 - val_loss: 0.0900\n",
      "Validation: mean_average_precision(0.0): 0.6654279587941297\n",
      "Epoch 7/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0821 - val_loss: 0.0896\n",
      "Validation: mean_average_precision(0.0): 0.6668269018575894\n",
      "Epoch 8/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0792 - val_loss: 0.0885\n",
      "Validation: mean_average_precision(0.0): 0.6723704781393599\n",
      "Epoch 9/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0754 - val_loss: 0.0895\n",
      "Validation: mean_average_precision(0.0): 0.6552521148587158\n",
      "Epoch 10/10\n",
      "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0731 - val_loss: 0.0910\n",
      "Validation: mean_average_precision(0.0): 0.6695447388956829\n"
     ]
    }
   ],
   "source": [
    "# train as ranking task\n",
    "model.params['task'] = mz.tasks.Ranking()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model,\n",
    "                                           x=pred_X,\n",
    "                                           y=pred_Y,\n",
    "                                           once_every=1,\n",
    "                                           batch_size=len(pred_Y))\n",
    "history = model.fit(x = [train_X['text_left'],\n",
    "                         train_X['text_right']],                  # (20360, 1000)\n",
    "                    y = train_Y,                                  # (20360, 2)\n",
    "                    validation_data = (val_X, val_Y),\n",
    "                    callbacks=[evaluate],\n",
    "                    batch_size = 32,\n",
    "                    epochs = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 8650 samples, validate on 1130 samples\n",
      "Epoch 1/10\n",
      "8650/8650 [==============================] - 68s 8ms/step - loss: 0.3628 - val_loss: 0.3552\n",
      "Epoch 2/10\n",
      "8650/8650 [==============================] - 63s 7ms/step - loss: 0.3285 - val_loss: 0.3591\n",
      "Epoch 3/10\n",
      "8650/8650 [==============================] - 63s 7ms/step - loss: 0.3105 - val_loss: 0.3681\n",
      "Epoch 4/10\n",
      "8650/8650 [==============================] - 64s 7ms/step - loss: 0.3012 - val_loss: 0.3166\n",
      "Epoch 5/10\n",
      "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2888 - val_loss: 0.2961\n",
      "Epoch 6/10\n",
      "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2801 - val_loss: 0.3362\n",
      "Epoch 7/10\n",
      "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2692 - val_loss: 0.3324\n",
      "Epoch 8/10\n",
      "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2609 - val_loss: 0.3172\n",
      "Epoch 9/10\n",
      "8650/8650 [==============================] - 58s 7ms/step - loss: 0.2542 - val_loss: 0.3296\n",
      "Epoch 10/10\n",
      "8650/8650 [==============================] - 53s 6ms/step - loss: 0.2365 - val_loss: 0.3058\n"
     ]
    }
   ],
   "source": [
    "# train as classification task \n",
    "\n",
    "from keras.utils import to_categorical\n",
    "train_Y = to_categorical(train_Y)\n",
    "val_Y = to_categorical(val_Y)\n",
    "\n",
    "model.params['task'] = mz.tasks.Classification(num_classes=2)\n",
    "\n",
    "history = model.fit(x = [train_X['text_left'],\n",
    "                         train_X['text_right']],                  # (20360, 1000)\n",
    "                    y = train_Y,                                  # (20360, 2)\n",
    "                    validation_data = (val_X, val_Y),\n",
    "                    batch_size = 32,\n",
    "                    epochs = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mz_play",
   "language": "python",
   "name": "mz_play"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
